{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a machine learning model with scikit-learn ([video #4](https://www.youtube.com/watch?v=RlQuVL6-qe8&list=PL5-da3qGB5ICeMbQuqbbCOQWcS6OYBr5A&index=4))\n",
    "\n",
    "Created by [Data School](http://www.dataschool.io/). Watch all 9 videos on [YouTube](https://www.youtube.com/playlist?list=PL5-da3qGB5ICeMbQuqbbCOQWcS6OYBr5A). Download the notebooks from [GitHub](https://github.com/justmarkham/scikit-learn-videos).\n",
    "\n",
    "**Note:** This notebook uses Python 3.6 and scikit-learn 0.19.1. The original notebook (shown in the video) used Python 2.7 and scikit-learn 0.16, and can be downloaded from the [archive branch](https://github.com/justmarkham/scikit-learn-videos/tree/archive)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agenda\n",
    "\n",
    "- What is the **K-nearest neighbors** classification model?\n",
    "- What are the four steps for **model training and prediction** in scikit-learn?\n",
    "- How can I apply this pattern to **other machine learning models**?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reviewing the iris dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"300\"\n",
       "            height=\"200\"\n",
       "            src=\"http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x10fb4e4a8>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import IFrame\n",
    "IFrame('http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', width=300, height=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 150 **observations**\n",
    "- 4 **features** (sepal length, sepal width, petal length, petal width)\n",
    "- **Response** variable is the iris species\n",
    "- **Classification** problem since response is categorical\n",
    "- More information in the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml/datasets/Iris)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## K-nearest neighbors (KNN) classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Pick a value for K.\n",
    "2. Search for the K observations in the training data that are \"nearest\" to the measurements of the unknown iris.\n",
    "3. Use the most popular response value from the K nearest neighbors as the predicted response value for the unknown iris."
   ]
  },
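  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps above can be sketched from scratch with NumPy. This is only an illustration, assuming Euclidean distance and a simple majority vote. The helper name `knn_predict` and the toy data are hypothetical, and scikit-learn's `KNeighborsClassifier` (used later in this notebook) is more general and more efficient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import Counter\n",
    "\n",
    "def knn_predict(X_train, y_train, x_new, k=1):\n",
    "    # hypothetical helper: predict the class of a single observation\n",
    "    X_train = np.asarray(X_train, dtype=float)\n",
    "    x_new = np.asarray(x_new, dtype=float)\n",
    "    # Euclidean distance from x_new to every training observation\n",
    "    distances = np.sqrt(((X_train - x_new) ** 2).sum(axis=1))\n",
    "    # step 2: indices of the k nearest training observations\n",
    "    nearest = np.argsort(distances)[:k]\n",
    "    # step 3: most popular response value among those neighbors\n",
    "    votes = Counter(y_train[i] for i in nearest)\n",
    "    return votes.most_common(1)[0][0]\n",
    "\n",
    "# toy training data (made up for illustration): two features, two classes\n",
    "X_toy = [[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [4.0, 4.0], [4.2, 3.9], [3.8, 4.1]]\n",
    "y_toy = [0, 0, 0, 1, 1, 1]\n",
    "\n",
    "# step 1: pick a value for K, then predict two unknown observations\n",
    "print(knn_predict(X_toy, y_toy, [1.1, 1.0], k=3))\n",
    "print(knn_predict(X_toy, y_toy, [4.0, 4.1], k=3))"
   ]
  },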
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example training data\n",
    "\n",
    "![Training data](images/04_knn_dataset.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### KNN classification map (K=1)\n",
    "\n",
    "![1NN classification map](images/04_1nn_map.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### KNN classification map (K=5)\n",
    "\n",
    "![5NN classification map](images/04_5nn_map.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Image Credits: [Data3classes](http://commons.wikimedia.org/wiki/File:Data3classes.png#/media/File:Data3classes.png), [Map1NN](http://commons.wikimedia.org/wiki/File:Map1NN.png#/media/File:Map1NN.png), [Map5NN](http://commons.wikimedia.org/wiki/File:Map5NN.png#/media/File:Map5NN.png) by Agor153. Licensed under CC BY-SA 3.0*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import load_iris function from datasets module\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "# save \"bunch\" object containing iris dataset and its attributes\n",
    "iris = load_iris()\n",
    "\n",
    "# store feature matrix in \"X\"\n",
    "X = iris.data\n",
    "\n",
    "# store response vector in \"y\"\n",
    "y = iris.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(150, 4)\n",
      "(150,)\n"
     ]
    }
   ],
   "source": [
    "# print the shapes of X and y\n",
    "print(X.shape)\n",
    "print(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scikit-learn 4-step modeling pattern"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 1:** Import the class you plan to use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 2:** \"Instantiate\" the \"estimator\"\n",
    "\n",
    "- \"Estimator\" is scikit-learn's term for model\n",
    "- \"Instantiate\" means \"make an instance of\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "knn = KNeighborsClassifier(n_neighbors=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Name of the object does not matter\n",
    "- Can specify tuning parameters (aka \"hyperparameters\") during this step\n",
    "- All parameters not specified are set to their defaults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
      "           metric_params=None, n_jobs=1, n_neighbors=1, p=2,\n",
      "           weights='uniform')\n"
     ]
    }
   ],
   "source": [
    "print(knn)"
   ]
  },
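  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see which values the unspecified parameters take, one option is the estimator's `get_params` method, which returns the parameters as a dictionary. A minimal sketch (the variable names `knn_default` and `knn_custom` are only for illustration) that compares an estimator created with defaults against one with `n_neighbors` set explicitly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# estimator created without arguments: every parameter keeps its default\n",
    "knn_default = KNeighborsClassifier()\n",
    "# estimator with one tuning parameter set explicitly\n",
    "knn_custom = KNeighborsClassifier(n_neighbors=1)\n",
    "\n",
    "params_default = knn_default.get_params()\n",
    "params_custom = knn_custom.get_params()\n",
    "\n",
    "# collect (default, custom) pairs for the parameters that differ\n",
    "changed = {name: (params_default[name], params_custom[name])\n",
    "           for name in params_default\n",
    "           if params_default[name] != params_custom[name]}\n",
    "print(changed)"
   ]
  },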
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 3:** Fit the model with data (aka \"model training\")\n",
    "\n",
    "- Model is learning the relationship between X and y\n",
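    "- Occurs in-place: `fit` updates the `knn` object itself, so the result does not need to be assigned to a new variable"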
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=1, n_neighbors=1, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn.fit(X, y)"
   ]
  },
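  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see that fitting occurs in-place is to check that `fit` returns the estimator itself and that the fitted object gains learned attributes such as `classes_`. A small sketch, reloading the iris data so that it runs on its own (the name `knn_check` is only for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "iris = load_iris()\n",
    "X, y = iris.data, iris.target\n",
    "\n",
    "knn_check = KNeighborsClassifier(n_neighbors=1)\n",
    "# check whether the learned attribute exists before fitting\n",
    "print(hasattr(knn_check, 'classes_'))\n",
    "\n",
    "# fit returns the estimator it was called on\n",
    "returned = knn_check.fit(X, y)\n",
    "print(returned is knn_check)\n",
    "\n",
    "# the response values seen during fitting are now stored on the object\n",
    "print(knn_check.classes_)"
   ]
  },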
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 4:** Predict the response for a new observation\n",
    "\n",
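    "- New observations are called \"out-of-sample\" data\n",
    "- The fitted model uses what it learned during training to make the prediction\n",
    "- `predict` expects a 2D array-like with one row per observation, hence the double brackets"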
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn.predict([[3, 5, 4, 2]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Returns a NumPy array\n",
    "- Can predict for multiple observations at once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 1])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_new = [[3, 5, 4, 2], [5, 4, 3, 2]]\n",
    "knn.predict(X_new)"
   ]
  },
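  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The integers in the prediction are encoded species. Since the encoding follows the order of `iris.target_names` in the bunch returned by `load_iris`, the predictions can be translated back to names. A minimal sketch that repeats the earlier steps so that it runs on its own:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "iris = load_iris()\n",
    "knn = KNeighborsClassifier(n_neighbors=1)\n",
    "knn.fit(iris.data, iris.target)\n",
    "\n",
    "X_new = [[3, 5, 4, 2], [5, 4, 3, 2]]\n",
    "predicted = knn.predict(X_new)\n",
    "\n",
    "# target_names is an array of species names indexed by the encoded response\n",
    "predicted_names = iris.target_names[predicted]\n",
    "for code, name in zip(predicted, predicted_names):\n",
    "    print(code, name)"
   ]
  },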
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using a different value for K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# instantiate the model (using the value K=5)\n",
    "knn = KNeighborsClassifier(n_neighbors=5)\n",
    "\n",
    "# fit the model with data\n",
    "knn.fit(X, y)\n",
    "\n",
    "# predict the response for new observations\n",
    "knn.predict(X_new)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using a different classification model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 0])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# import the class\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# instantiate the model (using the default parameters)\n",
    "logreg = LogisticRegression()\n",
    "\n",
    "# fit the model with data\n",
    "logreg.fit(X, y)\n",
    "\n",
    "# predict the response for new observations\n",
    "logreg.predict(X_new)"
   ]
  },
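  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because every estimator exposes the same `fit` and `predict` methods, the four steps can be applied to several models in a loop. A sketch under a few assumptions: the `models` dictionary and its labels are made up for illustration, and `max_iter=1000` is passed to `LogisticRegression` only to help it converge on recent scikit-learn versions, so its predictions may differ from the output above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step 1: import the classes\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "iris = load_iris()\n",
    "X, y = iris.data, iris.target\n",
    "X_new = [[3, 5, 4, 2], [5, 4, 3, 2]]\n",
    "\n",
    "# step 2: instantiate each estimator (the labels are arbitrary)\n",
    "models = {\n",
    "    'KNN (K=1)': KNeighborsClassifier(n_neighbors=1),\n",
    "    'KNN (K=5)': KNeighborsClassifier(n_neighbors=5),\n",
    "    'Logistic regression': LogisticRegression(max_iter=1000),\n",
    "}\n",
    "\n",
    "predictions = {}\n",
    "for label, model in models.items():\n",
    "    # step 3: fit the model with data\n",
    "    model.fit(X, y)\n",
    "    # step 4: predict the response for the new observations\n",
    "    predictions[label] = model.predict(X_new)\n",
    "    print(label, predictions[label])"
   ]
  },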
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Resources\n",
    "\n",
    "- [Nearest Neighbors](http://scikit-learn.org/stable/modules/neighbors.html) (user guide), [KNeighborsClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html) (class documentation)\n",
    "- [Logistic Regression](http://scikit-learn.org/stable/modules/linear_model.html#logistic-regression) (user guide), [LogisticRegression](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) (class documentation)\n",
    "- [Videos from An Introduction to Statistical Learning](http://www.dataschool.io/15-hours-of-expert-machine-learning-videos/)\n",
    "    - Classification Problems and K-Nearest Neighbors (Chapter 2)\n",
    "    - Introduction to Classification (Chapter 4)\n",
    "    - Logistic Regression and Maximum Likelihood (Chapter 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comments or Questions?\n",
    "\n",
    "- Email: <kevin@dataschool.io>\n",
    "- Website: http://dataschool.io\n",
    "- Twitter: [@justmarkham](https://twitter.com/justmarkham)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    @font-face {\n",
       "        font-family: \"Computer Modern\";\n",
       "        src: url('http://mirrors.ctan.org/fonts/cm-unicode/fonts/otf/cmunss.otf');\n",
       "    }\n",
       "    div.cell{\n",
       "        width: 90%;\n",
       "/*        margin-left:auto;*/\n",
       "/*        margin-right:auto;*/\n",
       "    }\n",
       "    ul {\n",
       "        line-height: 145%;\n",
       "        font-size: 90%;\n",
       "    }\n",
       "    li {\n",
       "        margin-bottom: 1em;\n",
       "    }\n",
       "    h1 {\n",
       "        font-family: Helvetica, serif;\n",
       "    }\n",
       "    h4{\n",
       "        margin-top: 12px;\n",
       "        margin-bottom: 3px;\n",
       "       }\n",
       "    div.text_cell_render{\n",
       "        font-family: Computer Modern, \"Helvetica Neue\", Arial, Helvetica, Geneva, sans-serif;\n",
       "        line-height: 145%;\n",
       "        font-size: 130%;\n",
       "        width: 90%;\n",
       "        margin-left:auto;\n",
       "        margin-right:auto;\n",
       "    }\n",
       "    .CodeMirror{\n",
       "            font-family: \"Source Code Pro\", source-code-pro,Consolas, monospace;\n",
       "    }\n",
       "/*    .prompt{\n",
       "        display: None;\n",
       "    }*/\n",
       "    .text_cell_render h5 {\n",
       "        font-weight: 300;\n",
       "        font-size: 16pt;\n",
       "        color: #4057A1;\n",
       "        font-style: italic;\n",
       "        margin-bottom: 0.5em;\n",
       "        margin-top: 0.5em;\n",
       "        display: block;\n",
       "    }\n",
       "\n",
       "    .warning{\n",
       "        color: rgb( 240, 20, 20 )\n",
       "        }\n",
       "</style>\n",
       "<script>\n",
       "    MathJax.Hub.Config({\n",
       "                        TeX: {\n",
       "                           extensions: [\"AMSmath.js\"]\n",
       "                           },\n",
       "                tex2jax: {\n",
       "                    inlineMath: [ ['$','$'], [\"\\\\(\",\"\\\\)\"] ],\n",
       "                    displayMath: [ ['$$','$$'], [\"\\\\[\",\"\\\\]\"] ]\n",
       "                },\n",
       "                displayAlign: 'center', // Change this to 'center' to center equations.\n",
       "                \"HTML-CSS\": {\n",
       "                    styles: {'.MathJax_Display': {\"margin\": 4}}\n",
       "                }\n",
       "        });\n",
       "</script>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import HTML\n",
    "def css_styling():\n",
    "    # read the custom stylesheet and close the file when done\n",
    "    with open(\"styles/custom.css\", \"r\") as f:\n",
    "        styles = f.read()\n",
    "    return HTML(styles)\n",
    "css_styling()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
